#!/usr/bin/env python
# Copyright (c) 2016, Konstantinos Kamnitsas
# All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the BSD license. See the accompanying LICENSE file
# or read the terms at https://opensource.org/licenses/BSD-3-Clause.

from __future__ import absolute_import, print_function, division
import sys
import os
import argparse
import traceback

# Raise the interpreter's recursion limit before the deepmedic modules below are imported.
# NOTE(review): the reason for the value 20000 is not visible in this file - confirm before changing it.
sys.setrecursionlimit(20000)

from deepmedic.frontEnd.configParsing.utils import abs_from_rel_path

from deepmedic.frontEnd.configParsing.modelConfig import ModelConfig
from deepmedic.frontEnd.configParsing.trainConfig import TrainConfig
from deepmedic.frontEnd.configParsing.testConfig import TestConfig

from deepmedic.frontEnd.trainSession import TrainSession
from deepmedic.frontEnd.testSession import TestSession

from deepmedic.frontEnd.configParsing.modelParams import ModelParameters

from tensorflow.python.client import device_lib


# Command line option names registered by setup_arg_parser().
OPT_MODEL = "-model"  # Config file that describes the model architecture.
OPT_TRAIN = "-train"  # Config file of a training session.
OPT_TEST = "-test"  # Config file of a testing session.
OPT_LOAD = "-load"  # Path to a saved checkpoint to train or test with.

# Device selection option and the values it accepts.
OPT_DEVICE = "-dev"
ARG_CPU_PROC = "cpu"
ARG_GPU_PROC = "cuda"  # May be followed by a device number, e.g. "cuda0".
DEF_DEV_PROC = ARG_CPU_PROC  # Device used when the option is not given.

# Flag (takes no argument) that asks a training session to reset the optimization state.
OPT_RESET = "-resetopt"


def str_is_int(s):
    """Tell whether `s` is accepted by the builtin int().

    Returns True when int(s) succeeds and False when it raises ValueError.
    Any other exception raised by int() is not handled and propagates.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
    
def setup_arg_parser():
    """Build the argparse parser for the DeepMedic command line.

    Registers the model / train / test config options, the checkpoint option,
    the device option (defaulting to DEF_DEV_PROC) and the optimizer-reset flag.
    Mutual constraints between the options are NOT enforced by the parser;
    the help texts only describe them and the caller has to check them.

    Returns:
        argparse.ArgumentParser: the configured parser. Parsed values are stored
        under `model_cfg`, `train_cfg`, `test_cfg`, `saved_model`, `device`
        and `reset_trainer`.
    """
    description = (
        "\nThis software allows creation and supervised training of 3D, multi-scale CNN models for segmentation of structures in biomedical NIFTI volumes.\n"
        + "The project is hosted at: https://github.com/Kamnitsask/deepmedic \n"
        + "See the documentation for details on its use.\n"
        + "This software accompanies the research presented in:\n"
        + "Kamnitsas et al, \"Efficient Multi-Scale 3D CNN with Fully Connected CRF for Accurate Brain Lesion Segmentation\", Biomedical Image Analysis, 2016.\n"
        + "We hope our work aids you in your endeavours.\n"
        + "For questions and feedback contact: konstantinos.kamnitsas12@ic.ac.uk"
    )
    # RawTextHelpFormatter keeps the explicit "\n" line breaks of the help texts.
    parser = argparse.ArgumentParser(prog='DeepMedic',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     description=description)

    parser.add_argument(
        OPT_MODEL, dest='model_cfg', type=str,
        help="Specify the architecture of the model to be used, by providing a config file [MODEL_CFG].")

    parser.add_argument(
        OPT_TRAIN, dest='train_cfg', type=str,
        help=("Train a model with training parameters given by specifying config file [TRAIN_CFG].\n"
              + "This option must follow a [" + OPT_MODEL + " MODEL_CFG] option, so that architecture of the to-train model is specified.\n"
              + "Additionally, an existing checkpoint of the model can be specified in the [TRAIN_CFG] file or by the additional option [" + OPT_LOAD + "], to continue training it."))

    # The model config is always required, so testing only excludes the training option.
    parser.add_argument(
        OPT_TEST, dest='test_cfg', type=str,
        help=("Test with an existing model. The testing session's parameters should be given in config file [TEST_CFG].\n"
              + "This option must follow a [" + OPT_MODEL + " MODEL_CFG] option, so that architecture of the model is specified.\n"
              + "Existing pretrained model can be specified in the given [TEST_CFG] file or by the additional option [" + OPT_LOAD + "].\n"
              + "This option cannot be used in combination with [" + OPT_TRAIN + "]."))

    parser.add_argument(
        OPT_LOAD, dest='saved_model', type=str,
        help=("The path to a saved existing checkpoint with learnt weights of the model, to train or test with.\n"
              + "This option must follow a [" + OPT_TRAIN + "] or [" + OPT_TEST + "] option.\n"
              + "If given, this option will override any \"model\" parameters given in the [TRAIN_CFG] or [TEST_CFG] files."))

    parser.add_argument(
        OPT_DEVICE, default=DEF_DEV_PROC, dest='device', type=str,
        help=("Specify the device to run the process on. Values: [" + ARG_CPU_PROC + "] or [" + ARG_GPU_PROC + "] (default = " + DEF_DEV_PROC + ").\n"
              + "In the case of multiple GPUs, specify a particular GPU device with a number, in the format: " + OPT_DEVICE + " " + ARG_GPU_PROC + "0 \n"
              + "NOTE: For GPU processing, CUDA libraries must be first added in your environment's PATH and LD_LIBRARY_PATH. See accompanying documentation."))

    parser.add_argument(
        OPT_RESET, dest='reset_trainer', action='store_true',
        help=("Use optionally with a [" + OPT_TRAIN + "] command. Does not take an argument.\n"
              + "Usage: ./deepMedicRun " + OPT_MODEL + " /path/to/model/config " + OPT_TRAIN + " /path/to/train/config " + OPT_RESET + " ...etc...\n"
              + "Resets the model's optimization state before starting the training session (eg number of epochs already trained, current learning rate etc).\n"
              + "IMPORTANT: Trainable parameters are NOT reinitialized! \n"
              + "Useful to begin a secondary training session with new learning-rate schedule, in order to fine-tune a previously trained model (Doc., Sec. 3.2)"))

    return parser


def check_dev_passed_correctly(devArg):
    """Validate the value given to the device option; terminate the process if it is invalid.

    Accepted values are ARG_CPU_PROC, ARG_GPU_PROC, or ARG_GPU_PROC immediately
    followed by a non-negative decimal device number (e.g. "cuda0").

    Args:
        devArg (str): the value parsed for the device option.

    Raises:
        SystemExit: with status 1, after printing an error, when `devArg` is invalid.
    """
    if devArg in (ARG_CPU_PROC, ARG_GPU_PROC):
        return

    if devArg.startswith(ARG_GPU_PROC):
        gpu_index = devArg[len(ARG_GPU_PROC):]
        # Require plain decimal digits. int() would also accept forms such as "-1", "+1",
        # " 1" or "1_0", which are not usable device numbers once exported to CUDA_VISIBLE_DEVICES.
        if gpu_index and all(ch in "0123456789" for ch in gpu_index):
            return

    print("ERROR: Value for the [" + OPT_DEVICE + "] option was not specified correctly. Specify the device to run the process on. \n"
          + "\tValues: [" + ARG_CPU_PROC + "] or [" + ARG_GPU_PROC + "] (Default = " + DEF_DEV_PROC + ").\n"
          + "\tIn the case of multiple GPUs, specify a particular GPU device with a number, in the format: " + ARG_GPU_PROC + "2. Exiting.")
    # sys.exit is always available, unlike the interactive helper exit() that the site module installs.
    sys.exit(1)
    
    
def set_environment(dev_string):
    """Configure CUDA device visibility and pick the session device string.

    Args:
        dev_string (str): the value of the device option.

    Returns:
        "/CPU:0" when the cpu value is given (CUDA_VISIBLE_DEVICES is emptied),
        "/device:GPU:0" when the gpu value carries a device number
        (CUDA_VISIBLE_DEVICES is set to that number), otherwise None.
        With None, TF will get all cuda devices and assign to the first.
    """
    gpu_index = dev_string[len(ARG_GPU_PROC):]
    names_specific_gpu = dev_string.startswith(ARG_GPU_PROC) and str_is_int(gpu_index)

    # Stays None for the bare gpu value and for any unrecognised string.
    sess_device = None

    if dev_string == ARG_CPU_PROC:
        # Hide every cuda device from the process.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        sess_device = "/CPU:0"

    if names_specific_gpu:
        # Expose only the requested gpu, which then is device 0 of the process.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index
        sess_device = "/device:GPU:0"

    return sess_device

#################################################
#                        MAIN                   #
#################################################
if __name__ == '__main__':
    # Entry point of the launcher: validates the command line, builds a train or test
    # session from the given config files, runs it, and exits with status 1 on any failure.
    cwd = os.getcwd()
    parser = setup_arg_parser()
    args = parser.parse_args()

    if len(sys.argv) == 1:
        print("For help on the usage of this program, please use the option -h.")
        sys.exit(1)

    if not args.model_cfg:
        print("ERROR: Option [" + OPT_MODEL + "] must be specified, pointing to a [MODEL_CFG] file that describes the architecture.\n"
              + "Please try [-h] for more information. Exiting.")
        sys.exit(1)
    if not (args.train_cfg or args.test_cfg):
        print("ERROR: One of the options must be specified:\n"
              + "\t[" + OPT_TRAIN + "] to start a training session on a model.\n"
              + "\t[" + OPT_TEST + "] to test with an existing model.\n"
              + "Please try [-h] for more information. Exiting.")
        sys.exit(1)

    #Preliminary checks:
    if args.test_cfg and args.train_cfg:
        print("ERROR:\t[" + OPT_TEST + "] cannot be used in conjunction with [" + OPT_TRAIN + "].\n"
              + "\tTo test with an existing network, please just specify a configuration file for the testing process, which will include a path to a trained model, or specify a model with [" + OPT_LOAD + "]. Exiting.")
        sys.exit(1)

    if args.reset_trainer and not args.train_cfg:
        print("ERROR:\tThe option [" + OPT_RESET + "] can only be used together with the [" + OPT_TRAIN + "] option.\n\tPlease try -h for more information. Exiting.")
        sys.exit(1)

    # Parse main files. The model config is guaranteed to be given by the check above.
    abs_path_model_cfg = abs_from_rel_path(args.model_cfg, cwd)
    model_cfg = ModelConfig(abs_path_model_cfg)

    # Create session. Exactly one of train / test config is given at this point.
    if args.train_cfg:
        abs_path_train_cfg = abs_from_rel_path(args.train_cfg, cwd)
        session = TrainSession(TrainConfig(abs_path_train_cfg))
    else:
        abs_path_test_cfg = abs_from_rel_path(args.test_cfg, cwd)
        session = TestSession(TestConfig(abs_path_test_cfg))

    #Create output folders and logger.
    session.make_output_folders()
    session.setup_logger()

    log = session.get_logger()

    log.print3("")
    log.print3("======================== Starting new session ============================")
    log.print3("Command line arguments given: \n" + str(args))

    check_dev_passed_correctly(args.device)
    # Must run before the devices are listed, because it may change CUDA_VISIBLE_DEVICES.
    sess_device = set_environment(args.device)
    log.print3("Available devices to Tensorflow:\n" + str(device_lib.list_local_devices()))

    session_failed = False
    try:
        log.print3("CONFIG: The configuration file for the [model] given is: " + str(model_cfg.get_abs_path_to_cfg()))
        model_params = ModelParameters(log, model_cfg)
        model_params.print_params()

        # Sessions
        log.print3("CONFIG: The configuration file for the [session] was loaded from: " + str(session.get_abs_path_to_cfg()))
        session.override_file_cfg_with_cmd_line_cfg(args)
        _ = session.compile_session_params_from_cfg(model_params)

        if args.train_cfg:
            session.run_session(sess_device, model_params, args.reset_trainer)
        else:
            session.run_session(sess_device, model_params)
        # All done.
    except (Exception, KeyboardInterrupt) as e:
        # Top-level boundary: record the failure in the session log rather than losing the traceback.
        session_failed = True
        log.print3("")
        log.print3("ERROR: Caught exception from main process: " + str(e))
        log.print3(traceback.format_exc())

    log.print3("Finished.")

    # Report failure to the calling shell; a failed or interrupted session must not exit with status 0.
    if session_failed:
        sys.exit(1)
    

